# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member

"""Executable object for TVM Runtime"""
from typing import Any, Callable, Dict, List, Optional

import tvm

from tvm.contrib import utils as _utils
from . import PackedFunc, Module


class Executable:
    """The executable object generated by `tvm.compile`.

    Wraps a runtime ``Module``. The module is just-in-time compiled and
    linked on first use (see :py:meth:`jit`); functions can then be looked
    up with ``ex["name"]`` and the ``main`` function invoked with
    ``ex(*args, **kwargs)``.
    """

    def __init__(self, mod: Module):
        """Initialize the Executable object.

        Parameters
        ----------
        mod : tvm.runtime.Module
            The module produced by compilation. It may still contain
            imported modules that must be compiled and linked before it
            can run.
        """
        self.mod: Module = mod
        # Cached runnable module produced by `jit`; None until `jit` has run.
        self._jitted_mod: Optional[Module] = None

    def __getitem__(self, name: str) -> PackedFunc:
        """Get a PackedFunc by name from the jitted module.

        Imported modules are searched as well (``query_imports=True``).
        """
        return self.jit().get_function(name, query_imports=True)

    def __call__(self, *args, **kwargs) -> Any:
        """Call the ``main`` function of the jitted module."""
        return self.jit().main(*args, **kwargs)

    def jit(
        self,
        *,
        fcompile: Optional[Callable[[str, List[str], Dict[str, Any]], Any]] = None,
        addons: Optional[List[str]] = None,
        force_recompile: bool = False,
        **kwargs,
    ) -> Module:
        """Just-in-time compile and link the modules.

        The Executable returned by tvm.compile may not be directly
        runnable, as it may contain CUDA source files and objects that
        are yet to be compiled and linked.
        This function helps to create a runtime.Module for these cases.

        The resulting module is cached: later calls return the cached
        module (ignoring ``fcompile``, ``addons`` and ``kwargs``) unless
        ``force_recompile`` is True.

        Parameters
        ----------
        fcompile : function(target, file_list, kwargs), optional
            The compilation function used to create the final library
            object during export.

        addons : list of str, optional
            Additional object files to link against.

        force_recompile : bool, optional
            If True, ignore the cached module and redo the compile/link step.

        kwargs : dict, optional
            Additional arguments passed to fcompile

        Returns
        -------
        rt_mod: tvm.runtime.Module
            A runnable runtime module that can be passed to VirtualMachine.

        Examples
        --------
        .. code:: python

            ex = tvm.compile(mod, target)
            rt_mod = ex.jit()

        """

        # If the module is already jitted and we don't want to force a recompile,
        # return the cached module
        if self._jitted_mod is not None and not force_recompile:
            return self._jitted_mod

        # TODO(tvm-team): Update runtime.Module interface
        # to query these properties as bitmask.
        def _not_runnable(x):
            return x.kind in ("c", "static_library")

        # pylint:disable = protected-access
        not_runnable_list = self.mod._collect_from_import_tree(_not_runnable)

        if len(not_runnable_list) == 0:
            # Everything is runnable: use the module as-is. Cache it so that
            # repeated `__call__` / `__getitem__` lookups do not re-walk the
            # import tree every time.
            self._jitted_mod = self.mod
            return self._jitted_mod

        # Found a source module or another non-runnable module: it needs to be
        # exported to a shared library and loaded back.
        # TODO(tvm-team): Support runnable but not exportable module.
        # by collecting the link and allow export_library skip those modules.
        workspace_dir = _utils.tempdir()
        dso_path = workspace_dir.relpath("exported.so")
        self.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs)
        self._jitted_mod = tvm.runtime.load_module(dso_path)
        return self._jitted_mod

    def export_library(
        self,
        file_name: str,
        *,
        fcompile: Optional[Callable[[str, List[str], Dict[str, Any]], Any]] = None,
        addons: Optional[List[str]] = None,
        workspace_dir: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """
        Export the module and all imported modules into a single device library.

        This function only works on host LLVM modules. Other runtime::Module
        subclasses will work with this API, but they must fully implement
        the save and load mechanisms of modules, including saving to
        streams and files. This packs your non-shared library module
        into a single shared library which can later be loaded by TVM.

        Parameters
        ----------
        file_name : str
            The name of the shared library.

        fcompile : function(target, file_list, kwargs), optional
            The compilation function used to create the final library object
            during export.

            For example, when fcompile=_cc.create_shared, or when it is not supplied but
            module is "llvm," this is used to link all produced artifacts
            into a final dynamic library.

            This behavior is controlled by the type of object exported.
            If fcompile has attribute object_format, will compile host library
            to that format. Otherwise, will use default format "o".

        addons : list of str, optional
            Additional object files to link against.

        workspace_dir : str, optional
            The path of the directory used to create the intermediate
            artifacts when exporting the module.
            If this is not provided a temporary dir will be created.

        kwargs : dict, optional
            Additional arguments passed to fcompile

        Returns
        -------
        result of fcompile()  : unknown, optional
            If the compilation function returns an artifact it would be returned via
            export_library, if any.
        """
        # Thin delegation: all export logic lives in the wrapped module.
        return self.mod.export_library(
            file_name,
            fcompile=fcompile,
            addons=addons,
            workspace_dir=workspace_dir,
            **kwargs,
        )
